#!/usr/bin/env python3
# Copyright 2023 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import argparse
import json
from multiprocessing.pool import ThreadPool
import shlex
import shutil
from fnmatch import fnmatch
from pathlib import Path
from typing import Any, List, Tuple
from impl.common import (
    CROSVM_ROOT,
    all_tracked_files,
    chdir,
    cmd,
    cwd_context,
    parallel,
    print_timing_info,
    quoted,
    record_time,
)

# List of globs matching files in the source tree required by tests at runtime.
# This is hard-coded specifically for crosvm tests.
TEST_DATA_FILES = [
    # Required by nextest to obtain metadata
    "*.toml",
    # Configured by .cargo/config.toml to execute tests with the right emulator
    ".cargo/runner.py",
    # Required by plugin tests
    "crosvm_plugin/crosvm.h",
    "tests/plugin.policy",
]

# Tracked files that are left out of the package even when they match a glob in TEST_DATA_FILES.
TEST_DATA_EXCLUDE = [
    # config.toml is configured for x86 hosts. We cannot use that for remote tests.
    ".cargo/config.toml",
]

# Command helpers for the external tools this script invokes.
cargo = cmd("cargo")
tar = cmd("tar")
rust_strip = cmd("rust-strip")


def collect_rust_libs():
    """
    Collect rust shared libraries required by the tests at runtime.

    Yields (source path, path in the package) tuples for every `libstd-*` and `libtest-*` file
    found in the `lib` directory of the rustc sysroot.
    """
    sysroot = cmd("rustc --print=sysroot").stdout()
    sysroot_lib = Path(sysroot) / "lib"
    deps_dir = Path("debug/deps")
    # All libstd matches are emitted first, followed by all libtest matches.
    for pattern in ("libstd-*", "libtest-*"):
        for library in sysroot_lib.glob(pattern):
            yield (library, deps_dir / library.name)


def collect_test_binaries(metadata: Any, strip: bool) -> List[Tuple[Path, Path]]:
    """
    Collect all test binaries that are needed to run the tests.

    `metadata` is the parsed JSON output of
    `cargo nextest list --list-type binaries-only --message-format json`.

    Returns a list of (source path, path in the package) tuples: first one entry per test binary,
    then one per non-test binary. The path in the package is relative to the target directory.

    If `strip` is set, a stripped copy of each binary is written next to the original (with its
    suffix replaced by `.stripped`) and that copy is returned as the source path instead.
    """
    target_dir = Path(metadata["rust-build-meta"]["target-directory"])
    test_binaries = [
        Path(suite["binary-path"]).relative_to(target_dir)
        for suite in metadata["rust-binaries"].values()
    ]

    non_test_binaries = [
        Path(binary["path"])
        for crate in metadata["rust-build-meta"]["non-test-binaries"].values()
        for binary in crate
    ]

    def process_binary(binary_path: Path) -> Tuple[Path, Path]:
        source_path = target_dir / binary_path
        if not strip:
            return (source_path, binary_path)

        stripped_path = source_path.with_suffix(".stripped")
        # Only re-strip when the stripped copy is missing or older than the binary.
        if (
            not stripped_path.exists()
            or source_path.stat().st_ctime > stripped_path.stat().st_ctime
        ):
            # Pass the paths as separate arguments (like the rsync and tar invocations in this
            # file) so that paths containing whitespace are not split into multiple arguments.
            rust_strip("--strip-all", source_path, "-o", stripped_path).fg()
        return (stripped_path, binary_path)

    # Parallelize rust_strip calls. The context manager releases the worker threads once the
    # (blocking) map call has returned all results.
    with ThreadPool() as pool:
        return pool.map(process_binary, test_binaries + non_test_binaries)


def collect_test_data_files():
    """
    List additional files from the source tree that are required by tests at runtime.

    Yields a (file, file) tuple for every tracked file that matches at least one glob in
    TEST_DATA_FILES and is not listed in TEST_DATA_EXCLUDE.
    """
    for tracked in all_tracked_files():
        name = str(tracked)
        if not any(fnmatch(name, pattern) for pattern in TEST_DATA_FILES):
            continue
        if name in TEST_DATA_EXCLUDE:
            continue
        # Source and destination are the same path.
        yield (tracked, tracked)


def collect_files(metadata: Any, output_directory: Path, strip_binaries: bool):
    """
    Copy all files needed to run the tests into `output_directory`.

    The test binaries (optionally stripped), the rust shared libraries and the test data files
    are gathered into one manifest and then copied with rsync.
    """
    # Every manifest entry is a (source path, path in output_directory) tuple.
    manifest: List[Tuple[Path, Path]] = []
    manifest.extend(collect_test_binaries(metadata, strip=strip_binaries))
    manifest.extend(collect_rust_libs())
    manifest.extend(collect_test_data_files())

    # Make sure the parent folder of each destination exists before copying.
    target_folders = {
        (output_directory / destination).parent.resolve() for _, destination in manifest
    }
    for target_folder in target_folders:
        target_folder.mkdir(parents=True, exist_ok=True)

    # Run one rsync command per file in parallel to copy the files fast (and only if needed).
    copy_commands = [
        cmd("rsync -a", source, output_directory / destination)
        for source, destination in manifest
    ]
    parallel(*copy_commands).fg()


def generate_run_script(metadata: Any, output_directory: Path):
    """
    Write the nextest metadata files and a `run.sh` launcher into `output_directory`.

    `metadata` is the binaries metadata obtained from `cargo nextest list`; it is stored as
    `binaries-metadata.json`. The output of `cargo metadata` is stored as `cargo-metadata.json`.
    The generated `run.sh` changes into its own directory, runs `cargo-nextest nextest run`
    against those two files and forwards any arguments it receives to nextest.
    """
    # Generate metadata files for nextest
    binaries_metadata_file = "binaries-metadata.json"
    (output_directory / binaries_metadata_file).write_text(json.dumps(metadata))
    cargo_metadata_file = "cargo-metadata.json"
    cargo("metadata --format-version 1").write_to(output_directory / cargo_metadata_file)

    # Put together command line to run nextest
    run_cmd = [
        "cargo-nextest",
        "nextest",
        "run",
        f"--binaries-metadata={binaries_metadata_file}",
        f"--cargo-metadata={cargo_metadata_file}",
        "--target-dir-remap=.",
        "--workspace-remap=.",
    ]
    command_line = [
        "#!/usr/bin/env bash",
        # Abort if the directory change fails. Bash has no `die` command, so `|| die` would only
        # print "command not found" and nextest would still be started from the wrong directory.
        'cd "$(dirname "${BASH_SOURCE[0]}")" || exit 1',
        f'{shlex.join(run_cmd)} "$@"',
    ]

    # Write command to a unix shell script (terminated by a newline)
    shell_script = output_directory / "run.sh"
    shell_script.write_text("\n".join(command_line) + "\n")
    shell_script.chmod(0o755)

    # TODO(denniskempin): Add an equivalent windows bash script


def generate_archive(output_directory: Path, output_archive: Path):
    """
    Pack `output_directory` into the archive file `output_archive` using tar.

    tar is run from the parent of `output_directory`, so the archive contains the directory under
    its own name only.
    """
    parent_folder = output_directory.parent
    folder_name = output_directory.name
    with cwd_context(parent_folder):
        tar("-ca", folder_name, "-f", output_archive).fg()


def main():
    """
    Builds a package to execute tests remotely.

    ## Basic usage

    ```
    $ tools/nextest_package -o foo.tar.zst ... nextest args
    ```

    The archive will contain all necessary test binaries, required shared libraries and test data
    files required at runtime.
    A `run.sh` script is generated that invokes `cargo-nextest` with the required arguments. The
    archive can be copied anywhere and executed:

    ```
    $ tar xaf foo.tar.zst && cd foo.tar.d && ./run.sh
    ```

    ## Nextest Arguments

    All additional arguments will be passed to `nextest list`. Additional arguments to `nextest run`
    can be passed to the `run.sh` invocation.

    For example:

    ```
    $ tools/nextest_package -d foo --tests
    $ cd foo && ./run.sh --test-threads=1
    ```

    Will only list and package integration tests (--tests) and run them with --test-threads=1.

    ## Stripping Symbols

    Symbols are stripped from the binaries by default to reduce the package size. This can be
    disabled via the `--no-strip` argument.

    ## Exit status

    If neither `--output-directory` nor `--output-archive` is given, a usage error is printed to
    stderr and the script exits with status 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--no-strip",
        action="store_true",
        help="Do not strip symbols from the packaged binaries.",
    )
    parser.add_argument(
        "--output-directory",
        "-d",
        help="Directory to collect the package files in.",
    )
    parser.add_argument(
        "--output-archive",
        "-o",
        help="Archive file to generate from the output directory.",
    )
    parser.add_argument(
        "--clean",
        action="store_true",
        help="Remove an existing output directory before collecting files.",
    )
    parser.add_argument(
        "--timing-info",
        action="store_true",
        help="Print timing information when done.",
    )
    (args, nextest_list_args) = parser.parse_known_args()

    # Reject invalid invocations before doing any work. parser.error() reports on stderr and
    # exits with a non-zero status so that calling scripts can detect the failure.
    if not args.output_directory and not args.output_archive:
        parser.error("Must specify either --output-directory or --output-archive")

    # Relative output paths are resolved after this, i.e. relative to CROSVM_ROOT.
    chdir(CROSVM_ROOT)

    # Determine output archive / directory
    output_archive = Path(args.output_archive).resolve() if args.output_archive else None
    if args.output_directory:
        output_directory = Path(args.output_directory).resolve()
    else:
        # The argument check above guarantees that an archive was requested in this branch.
        assert output_archive is not None
        # Derive the directory from the archive name, e.g. foo.tar.zst -> foo.tar.d
        output_directory = output_archive.with_suffix(".d")

    if args.clean and output_directory.exists():
        shutil.rmtree(output_directory)

    # The first call prints the test list to the terminal, the second one obtains the
    # machine-readable binaries metadata.
    with record_time("Listing tests"):
        cargo(
            "nextest list",
            *(quoted(a) for a in nextest_list_args),
        ).fg()
    with record_time("Listing tests metadata"):
        metadata = cargo(
            "nextest list --list-type binaries-only --message-format json",
            *(quoted(a) for a in nextest_list_args),
        ).json()

    with record_time("Collecting files"):
        collect_files(metadata, output_directory, strip_binaries=not args.no_strip)
        generate_run_script(metadata, output_directory)

    if output_archive:
        with record_time("Generating archive"):
            generate_archive(output_directory, output_archive)

    if args.timing_info:
        print_timing_info()


# Run the packaging tool only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
